{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1cc20075-e2a4-4ff7-9a50-0c86b0aecd1f",
   "metadata": {},
   "source": [
    "# L4: Quantizing Models"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bd5a462a-32f9-4ef1-82f7-4f1753b60722",
   "metadata": {},
   "source": [
    "<p style=\"background-color:#fff6e4; padding:15px; border-width:3px; border-color:#f5ecda; border-style:solid; border-radius:6px\"> ⏳ <b>Note <code>(Kernel Starting)</code>:</b> This notebook takes about 30 seconds to be ready to use. You may start and watch the video while you wait.</p>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98633b04-5980-4285-8fcc-fe681862627b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "# Use input resolution of the network\n",
    "input_shape = (1, 3, 1024, 2048)\n",
    "\n",
    "# Load 100 RGB images of urban scenes \n",
    "dataset = load_dataset(\"UrbanSyn/UrbanSyn\", \n",
    "                split=\"train\", \n",
    "                data_files=\"rgb/*_00*.png\")\n",
    "\n",
    "# Hold out a single image for testing. The fixed seed makes the shuffle\n",
    "# deterministic, so the same test image is picked after every kernel restart.\n",
    "dataset = dataset.train_test_split(test_size=1, seed=42)\n",
    "\n",
    "calibration_dataset = dataset[\"train\"]\n",
    "test_dataset = dataset[\"test\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4fbb9bf-bc86-4595-9db8-21fdccef8f43",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index the row first so only this one image is decoded; indexing the\n",
    "# \"image\" column first can decode every image before picking one.\n",
    "calibration_dataset[0][\"image\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "99585e8c-2380-41ba-ae30-af4497a264d4",
   "metadata": {},
   "source": [
    "## Set up calibration/inference pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "644e4224-4d52-4eea-a3f1-caa2a49abb00",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torchvision import transforms\n",
    "\n",
    "# Convert the PIL image above to Torch Tensor\n",
    "preprocess = transforms.ToTensor()\n",
    "\n",
    "# Get a sample image in the test dataset\n",
    "test_sample_pil = test_dataset[0][\"image\"]\n",
    "\n",
    "# unsqueeze(0) adds the leading batch dimension the model is called with\n",
    "test_sample = preprocess(test_sample_pil).unsqueeze(0)\n",
    "\n",
    "# Show the shape rather than the raw values: it is what has to line up\n",
    "# with input_shape\n",
    "print(test_sample.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e132493-3d2e-4309-9e8d-ab4e9484501d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "\n",
    "def postprocess(output_tensor, input_image_pil):\n",
    "    \"\"\"Overlay the predicted segmentation classes on the input image.\n",
    "\n",
    "    Args:\n",
    "        output_tensor: Model output for a single image. It is indexed as\n",
    "            ``[0]`` and arg-maxed over dim 0, so presumably it holds\n",
    "            (1, num_classes, H, W) class scores -- confirm against the model.\n",
    "        input_image_pil: PIL image the prediction was computed from.\n",
    "\n",
    "    Returns:\n",
    "        PIL RGB image: a 50/50 blend of the input and the color-coded mask.\n",
    "    \"\"\"\n",
    "    # PIL reports (width, height) while interpolate expects (height, width).\n",
    "    # Sizing from the image itself (not a notebook-level constant) guarantees\n",
    "    # that mask and image match, which Image.blend requires.\n",
    "    target_size = input_image_pil.size[::-1]\n",
    "\n",
    "    # Upsample the output to the original size\n",
    "    output_tensor_upsampled = F.interpolate(\n",
    "        output_tensor.detach(), size=target_size, mode=\"bilinear\",\n",
    "    )\n",
    "\n",
    "    # Get top predicted class per pixel; byte() already yields uint8\n",
    "    output_predictions = output_tensor_upsampled[0].argmax(0).byte().cpu().numpy()\n",
    "\n",
    "    # Palette-mode image whose pixel values are the class indices\n",
    "    color_mask = Image.fromarray(output_predictions).convert(\"P\")\n",
    "\n",
    "    # Create an appropriate palette for the Cityscapes classes\n",
    "    palette = [\n",
    "        128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156,\n",
    "        190, 153, 153, 153, 153, 153, 250, 170, 30, 220, 220, 0,\n",
    "        107, 142, 35, 152, 251, 152, 70, 130, 180, 220, 20, 60,\n",
    "        255, 0, 0, 0, 0, 142, 0, 0, 70, 0, 60, 100, 0, 80, 100,\n",
    "        0, 0, 230, 119, 11, 32]\n",
    "    # Pad with zeros up to the 256 RGB entries of a full palette\n",
    "    palette = palette + (256 * 3 - len(palette)) * [0]\n",
    "    color_mask.putpalette(palette)\n",
    "\n",
    "    # Overlay over original image; blend needs both images in the same mode\n",
    "    out = Image.blend(input_image_pil.convert(\"RGB\"), color_mask.convert(\"RGB\"), 0.5)\n",
    "    return out"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41be859e-9a2a-4767-805a-d0227b4e326b",
   "metadata": {},
   "source": [
    "## Set up model in floating point"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdc7775b-bdc2-4414-984e-bac3d290430d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from qai_hub_models.models.ffnet_40s.model import FFNet40S\n",
    "model = FFNet40S.from_pretrained().model.eval()\n",
    "\n",
    "# Run sample input through the model. This is inference only, so skip\n",
    "# gradient tracking.\n",
    "with torch.no_grad():\n",
    "    test_output_fp32 = model(test_sample)\n",
    "\n",
    "postprocess(test_output_fp32, test_sample_pil)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0fa32510-cf8f-44e6-ae45-8579fe7c9ab9",
   "metadata": {},
   "source": [
    "## Prepare Quantized Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "442c8d18-98a7-4d77-9d48-bd5b09a01b3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from qai_hub_models.models._shared.ffnet_quantized.model import FFNET_AIMET_CONFIG\n",
    "from aimet_torch.batch_norm_fold import fold_all_batch_norms\n",
    "from aimet_torch.model_preparer import prepare_model\n",
    "from aimet_torch.quantsim import QuantizationSimModel\n",
    "\n",
    "# Prepare model for 8-bit quantization\n",
    "# NOTE: the return value of fold_all_batch_norms is ignored and `model` is\n",
    "# rebound on the following line, so re-running only this cell starts from the\n",
    "# already-prepared model. Re-run the model-loading cell first.\n",
    "fold_all_batch_norms(model, [input_shape])\n",
    "model = prepare_model(model)\n",
    "\n",
    "# Setup quantization simulator\n",
    "quant_sim = QuantizationSimModel(\n",
    "    model,\n",
    "    quant_scheme=\"tf_enhanced\",\n",
    "    default_param_bw=8,              # Use bitwidth 8-bit\n",
    "    default_output_bw=8,             # 8-bit here too (presumably activations -- confirm in AIMET docs)\n",
    "    config_file=FFNET_AIMET_CONFIG,\n",
    "    dummy_input=torch.rand(input_shape),  # random tensor with the network input shape\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59e8e919-09d4-4df9-bd07-5bc65c5dfde6",
   "metadata": {},
   "source": [
    "## Perform post-training quantization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d325c9a-06a2-46ba-a8fa-e118e6512875",
   "metadata": {},
   "outputs": [],
   "source": [
    "size = 5  # Number of calibration images to use; must be < 100\n",
    "\n",
    "def pass_calibration_data(sim_model: torch.nn.Module, args):\n",
    "    \"\"\"Feed calibration images through ``sim_model``.\n",
    "\n",
    "    Passed to ``compute_encodings`` as its forward-pass callback.\n",
    "\n",
    "    Args:\n",
    "        sim_model: Module to run the calibration samples through.\n",
    "        args: ``(dataset,)`` or ``(dataset, num_samples)``. When\n",
    "            ``num_samples`` is omitted, the notebook-level ``size`` is used.\n",
    "    \"\"\"\n",
    "    dataset, *extra = args\n",
    "    num_samples = extra[0] if extra else size\n",
    "    # select() rejects out-of-range indices, so never ask for more rows than exist\n",
    "    num_samples = min(num_samples, len(dataset))\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for sample in dataset.select(range(num_samples)):\n",
    "            pil_image = sample[\"image\"]\n",
    "            input_batch = preprocess(pil_image).unsqueeze(0)\n",
    "\n",
    "            # Feed sample through for calibration\n",
    "            sim_model(input_batch)\n",
    "\n",
    "# Run Post-Training Quantization (PTQ)\n",
    "quant_sim.compute_encodings(pass_calibration_data, [calibration_dataset, size])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee33339c-d4b1-4279-bde2-afcbccdc0fba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the same test image through the quantization simulation model.\n",
    "# This is inference only, so skip gradient tracking.\n",
    "with torch.no_grad():\n",
    "    test_output_int8 = quant_sim.model(test_sample)\n",
    "\n",
    "postprocess(test_output_int8, test_sample_pil)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56af1e43-cd39-4281-9328-16f801a65c75",
   "metadata": {},
   "source": [
    "## Run Quantized model on-device"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1532257b-230b-49be-8452-0061b4ad4d26",
   "metadata": {},
   "source": [
    "<p style=\"background-color:#fff6ff; padding:15px; border-width:3px; border-color:#efe6ef; border-style:solid; border-radius:6px\"> 💻 &nbsp; <b>Access Utils File and Helper Functions:</b> To access the files for this notebook, 1) click on the <em>\"File\"</em> option on the top menu of the notebook and then 2) click on <em>\"Open\"</em>. For more help, please see the <em>\"Appendix - Tips and Help\"</em> Lesson.</p>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6331eeda-7a13-44e1-b44a-d00f492e99ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "import qai_hub\n",
    "import qai_hub_models\n",
    "\n",
    "# Obtain the API token from the helper in the local utils module rather than\n",
    "# hard-coding it in the notebook\n",
    "from utils import get_ai_hub_api_token\n",
    "ai_hub_api_token = get_ai_hub_api_token()\n",
    "\n",
    "# Shell escape: IPython substitutes $ai_hub_api_token with the value of the\n",
    "# Python variable before running the qai-hub CLI\n",
    "!qai-hub configure --api_token $ai_hub_api_token"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a11d359a-7634-45b8-9f58-c8c6a5707c01",
   "metadata": {},
   "source": [
    "<p style=\"background-color:#fff6e4; padding:15px; border-width:3px; border-color:#f5ecda; border-style:solid; border-radius:6px\"> ⏳ <b>Note:</b> To spread the load across various devices, we are selecting a random device. Feel free to change it to any other device you prefer.</p>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d433c887-f023-422a-851b-7334ced70f01",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "# Candidate devices. Each name appears exactly once so that every device is\n",
    "# equally likely to be picked.\n",
    "devices = [\n",
    "    \"Samsung Galaxy S22 Ultra 5G\",\n",
    "    \"Samsung Galaxy S22 5G\",\n",
    "    \"Samsung Galaxy S22+ 5G\",\n",
    "    \"Samsung Galaxy Tab S8\",\n",
    "    \"Xiaomi 12\",\n",
    "    \"Xiaomi 12 Pro\",\n",
    "    \"Samsung Galaxy S23\",\n",
    "    \"Samsung Galaxy S23+\",\n",
    "    \"Samsung Galaxy S23 Ultra\",\n",
    "    \"Samsung Galaxy S24\",\n",
    "    \"Samsung Galaxy S24 Ultra\",\n",
    "    \"Samsung Galaxy S24+\",\n",
    "]\n",
    "\n",
    "# Deliberately unseeded: the pick is meant to vary between runs\n",
    "selected_device = random.choice(devices)\n",
    "print(selected_device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d1e42b6-58fa-4104-9f81-bb40e79798cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the package's export module for the quantized FFNet-40S model, targeting\n",
    "# the device stored in selected_device. The quotes keep a multi-word device\n",
    "# name together as a single argument.\n",
    "%run -m qai_hub_models.models.ffnet_40s_quantized.export -- --device \"$selected_device\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "197a806e-2201-4f9d-9e6b-0ea431f9ce94",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd6fc865-45b5-4a25-be9c-ca934f9470c9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13ad84f5-ed7f-40bb-a007-df0d3e693f3c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87f2d067-02dc-4bbd-bf64-3b7d28805d4b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42d3e5b6-f9a8-4009-93d6-1e1c99f0ca19",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
